{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "import keras\n",
    "import math\n",
    "import numpy as np\n",
    "import os\n",
    "import sklearn.metrics as skm\n",
    "import sys\n",
    "sys.path.append(\"../../../ecg\")\n",
    "\n",
    "import load\n",
    "import util\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 8761/8761 [00:03<00:00, 2561.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8761/8761 [==============================] - 20s 2ms/step\n"
     ]
    }
   ],
   "source": [
    "# Paths can be overridden with environment variables so the notebook is not\n",
    "# tied to one machine's filesystem; the defaults reproduce the original run.\n",
    "model_path = os.environ.get(\n",
    "    \"ECG_MODEL_PATH\",\n",
    "    \"/deep/group/awni/ecg_models/default/1527627404-9/0.337-0.880-012-0.255-0.906.hdf5\")\n",
    "data_json = os.environ.get(\"ECG_DATA_JSON\", \"../dev.json\")\n",
    "\n",
    "# Load the preprocessor stored in the model's directory, then the dataset.\n",
    "preproc = util.load(os.path.dirname(model_path))\n",
    "dataset = load.load_dataset(data_json)\n",
    "ecgs, labels = preproc.process(*dataset)\n",
    "\n",
    "# `preproc`, `labels` and `probs` are used by the evaluation cells below.\n",
    "model = keras.models.load_model(model_path)\n",
    "probs = model.predict(ecgs, verbose=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "def stats(ground_truth, preds):\n",
    "    \"\"\"Per-class confusion counts computed over every flattened position.\n",
    "\n",
    "    Both arrays are argmax'ed over axis 2 (whose size gives the number of\n",
    "    classes) and flattened, so each position counts once.\n",
    "    Returns {class_index: (tp, fp, fn, tn)}.\n",
    "    \"\"\"\n",
    "    true_cls = np.argmax(ground_truth, axis=2).ravel()\n",
    "    pred_cls = np.argmax(preds, axis=2).ravel()\n",
    "    counts = {}\n",
    "    for cls in range(ground_truth.shape[2]):\n",
    "        is_true = true_cls == cls\n",
    "        is_pred = pred_cls == cls\n",
    "        tp = np.sum(is_true & is_pred)\n",
    "        fp = np.sum(is_pred & ~is_true)\n",
    "        fn = np.sum(is_true) - tp\n",
    "        tn = np.sum(~is_true) - fp\n",
    "        counts[cls] = (tp, fp, fn, tn)\n",
    "    return counts\n",
    "\n",
    "def to_set(preds):\n",
    "    \"\"\"Return, per record, a list of the distinct argmax classes (axis 2).\"\"\"\n",
    "    per_step = np.argmax(preds, axis=2)\n",
    "    result = []\n",
    "    for row in per_step:\n",
    "        result.append(list(set(row)))\n",
    "    return result\n",
    "\n",
    "def set_stats(ground_truth, preds):\n",
    "    \"\"\"Per-class confusion counts at the record level.\n",
    "\n",
    "    A record counts as positive for a class when that class is the argmax\n",
    "    anywhere in the record (see to_set). Returns {class_index: (tp, fp, fn, tn)}.\n",
    "    \"\"\"\n",
    "    true_sets = to_set(ground_truth)\n",
    "    pred_sets = to_set(preds)\n",
    "    counts = {}\n",
    "    for cls in range(ground_truth.shape[2]):\n",
    "        tp = fp = fn = tn = 0\n",
    "        for truth, pred in zip(true_sets, pred_sets):\n",
    "            in_truth = cls in truth\n",
    "            in_pred = cls in pred\n",
    "            if in_truth and in_pred:\n",
    "                tp += 1\n",
    "            elif in_pred:\n",
    "                fp += 1\n",
    "            elif in_truth:\n",
    "                fn += 1\n",
    "            else:\n",
    "                tn += 1\n",
    "        counts[cls] = (tp, fp, fn, tn)\n",
    "    return counts\n",
    "\n",
    "def compute_f1(tp, fp, fn, tn):\n",
    "    \"\"\"F1 score for one class, plus its support.\n",
    "\n",
    "    Args:\n",
    "        tp, fp, fn, tn: confusion counts for the class. tn is accepted so a\n",
    "            (tp, fp, fn, tn) tuple can be splatted in; F1 does not use it.\n",
    "    Returns:\n",
    "        (f1, support) with support = tp + fn. F1 is reported as 0.0 when\n",
    "        tp == 0, where precision and/or recall is zero or undefined and the\n",
    "        harmonic mean would otherwise divide by zero.\n",
    "    \"\"\"\n",
    "    support = tp + fn\n",
    "    if tp == 0:\n",
    "        return 0.0, support\n",
    "    precision = tp / float(tp + fp)\n",
    "    recall = tp / float(tp + fn)\n",
    "    f1 = 2 * precision * recall / (precision + recall)\n",
    "    return f1, support\n",
    "\n",
    "def print_results(seq_sd, set_sd, classes=None):\n",
    "    \"\"\"Print per-class sequence- and set-level F1 and their weighted averages.\n",
    "\n",
    "    Args:\n",
    "        seq_sd: {class_index: (tp, fp, fn, tn)} as returned by stats().\n",
    "        set_sd: the same mapping as returned by set_stats(); it must contain\n",
    "            every key of seq_sd.\n",
    "        classes: class names indexable by class_index. Defaults to the\n",
    "            notebook-global preproc.classes so existing calls keep working.\n",
    "    \"\"\"\n",
    "    if classes is None:\n",
    "        classes = preproc.classes\n",
    "    print(\"\\t\\t Seq F1    Set F1\")\n",
    "    seq_tf1 = 0; seq_tot = 0\n",
    "    set_tf1 = 0; set_tot = 0\n",
    "    for k, v in seq_sd.items():\n",
    "        set_f1, n = compute_f1(*set_sd[k])\n",
    "        set_tf1 += n * set_f1\n",
    "        set_tot += n\n",
    "        seq_f1, n = compute_f1(*v)\n",
    "        seq_tf1 += n * seq_f1\n",
    "        seq_tot += n\n",
    "        print(\"{:>10} {:10.3f} {:10.3f}\".format(classes[k], seq_f1, set_f1))\n",
    "    # Averages are weighted by the per-class count (tp + fn) from compute_f1.\n",
    "    print(\"{:>10} {:10.3f} {:10.3f}\".format(\n",
    "        \"Average\", seq_tf1 / float(seq_tot), set_tf1 / float(set_tot)))\n",
    "    \n",
    "def c_statistic_with_95p_confidence_interval(cstat, num_positives, num_negatives, z_alpha_2=1.96):\n",
    "    \"\"\"\n",
    "    Half-width of a normal-approximation confidence interval for an ROC AUC\n",
    "    (c-statistic), using the Hanley & McNeil standard error as described\n",
    "    under \"Confidence Interval for AUC\" here:\n",
    "      https://ncss-wpengine.netdna-ssl.com/wp-content/themes/ncss/pdf/Procedures/PASS/Confidence_Intervals_for_the_Area_Under_an_ROC_Curve.pdf\n",
    "    Args:\n",
    "        cstat: the c-statistic (area under the ROC curve).\n",
    "        num_positives: number of positive examples in the set.\n",
    "        num_negatives: number of negative examples in the set.\n",
    "        z_alpha_2 (optional): standard-normal critical value; 1.96 gives a\n",
    "            95% interval, 2.326 a 98% one, 2.576 a 99% one, etc.\n",
    "    Returns:\n",
    "        The half-width Y of the interval cstat +/- Y (a 95% interval with\n",
    "        the default z_alpha_2). The bounds are not clipped to [0, 1].\n",
    "    \"\"\"\n",
    "    auc = cstat\n",
    "    auc_sq = auc**2\n",
    "    q1 = auc / (2 - auc)\n",
    "    q2 = 2 * auc_sq / (1 + auc)\n",
    "    variance_numerator = (auc * (1 - auc)\n",
    "                          + (num_positives - 1) * (q1 - auc_sq)\n",
    "                          + (num_negatives - 1) * (q2 - auc_sq))\n",
    "    std_error = math.sqrt(variance_numerator / (num_positives * num_negatives))\n",
    "    return z_alpha_2 * std_error\n",
    "\n",
    "def roc_auc(ground_truth, probs, index):\n",
    "    \"\"\"One-vs-rest ROC AUC for class `index`, treating every position as a sample.\n",
    "\n",
    "    A position is positive when `index` is the argmax of ground_truth over\n",
    "    axis 2 there; its score is probs[..., index]. Presumably the arrays are\n",
    "    (records, steps, classes) -- TODO confirm against preproc.process.\n",
    "    Returns (num_positive_positions, num_negative_positions, auc).\n",
    "    \"\"\"\n",
    "    step_is_pos = np.argmax(ground_truth, axis=2) == index\n",
    "    n_pos = np.sum(step_is_pos)\n",
    "    n_neg = step_is_pos.size - n_pos\n",
    "    targets = step_is_pos.astype(int).ravel()\n",
    "    scores = probs[..., index].ravel()\n",
    "    return n_pos, n_neg, skm.roc_auc_score(targets, scores)\n",
    "\n",
    "def roc_auc_set(ground_truth, probs, index):\n",
    "    \"\"\"One-vs-rest ROC AUC for class `index`, treating every record as a sample.\n",
    "\n",
    "    A record is positive when `index` is the ground-truth argmax (axis 2) at\n",
    "    any position along axis 1; its score is the maximum of probs[..., index]\n",
    "    along axis 1.\n",
    "    Returns (num_positive_records, num_negative_records, auc).\n",
    "    \"\"\"\n",
    "    record_score = probs[..., index].max(axis=1)\n",
    "    record_is_pos = (np.argmax(ground_truth, axis=2) == index).any(axis=1)\n",
    "    n_pos = record_is_pos.sum()\n",
    "    n_neg = record_is_pos.size - n_pos\n",
    "    return n_pos, n_neg, skm.roc_auc_score(record_is_pos, record_score)\n",
    "\n",
    "def print_aucs(ground_truth, probs, int_to_class=None):\n",
    "    \"\"\"Print per-class ROC AUCs with confidence intervals.\n",
    "\n",
    "    Each row shows `AUC (lower-upper)` for the position-level evaluation\n",
    "    (roc_auc) followed by the record-level one (roc_auc_set); the last row\n",
    "    shows both AUCs averaged with weights equal to the number of positives.\n",
    "\n",
    "    Args:\n",
    "        ground_truth, probs: arrays passed unchanged to roc_auc/roc_auc_set.\n",
    "        int_to_class: {class_index: name}. Defaults to the notebook-global\n",
    "            preproc.int_to_class so existing calls keep working.\n",
    "    \"\"\"\n",
    "    if int_to_class is None:\n",
    "        int_to_class = preproc.int_to_class\n",
    "    seq_tauc = 0.0; seq_tot = 0.0\n",
    "    set_tauc = 0.0; set_tot = 0.0\n",
    "    print(\"\\t        AUC\")\n",
    "    for idx, cname in int_to_class.items():\n",
    "        pos, neg, seq_auc = roc_auc(ground_truth, probs, idx)\n",
    "        seq_tot += pos\n",
    "        seq_tauc += pos * seq_auc\n",
    "        seq_conf = c_statistic_with_95p_confidence_interval(seq_auc, pos, neg)\n",
    "        pos, neg, set_auc = roc_auc_set(ground_truth, probs, idx)\n",
    "        set_tot += pos\n",
    "        set_tauc += pos * set_auc\n",
    "        set_conf = c_statistic_with_95p_confidence_interval(set_auc, pos, neg)\n",
    "        # AUC lies in [0, 1]; clip the symmetric normal-approximation interval\n",
    "        # so it never reports impossible bounds such as 1.001.\n",
    "        print(\"{: <8}\\t{:.3f} ({:.3f}-{:.3f})\\t{:.3f} ({:.3f}-{:.3f})\".format(\n",
    "            cname,\n",
    "            seq_auc, max(seq_auc - seq_conf, 0.0), min(seq_auc + seq_conf, 1.0),\n",
    "            set_auc, max(set_auc - set_conf, 0.0), min(set_auc + set_conf, 1.0)))\n",
    "    print(\"Average\\t\\t{:.3f}\\t{:.3f}\".format(seq_tauc / seq_tot, set_tauc / set_tot))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\t\t Seq F1    Set F1\n",
      "        AF      0.914      0.914\n",
      "       AVB      0.805      0.839\n",
      "  BIGEMINY      0.917      0.896\n",
      "       EAR      0.652      0.699\n",
      "       IVR      0.721      0.758\n",
      "JUNCTIONAL      0.706      0.740\n",
      "     NOISE      0.911      0.847\n",
      "     SINUS      0.920      0.960\n",
      "       SVT      0.700      0.812\n",
      " TRIGEMINY      0.924      0.918\n",
      "        VT      0.769      0.848\n",
      "WENCKEBACH      0.779      0.822\n",
      "   Average      0.879      0.889\n",
      "\t        AUC\n",
      "AF      \t0.994 (0.994-0.995)\t0.994 (0.991-0.996)\n",
      "AVB     \t0.992 (0.990-0.993)\t0.990 (0.985-0.995)\n",
      "BIGEMINY\t0.999 (0.998-1.000)\t0.998 (0.994-1.001)\n",
      "EAR     \t0.977 (0.975-0.980)\t0.967 (0.957-0.977)\n",
      "IVR     \t0.996 (0.994-0.998)\t0.991 (0.984-0.998)\n",
      "JUNCTIONAL\t0.987 (0.985-0.989)\t0.984 (0.976-0.992)\n",
      "NOISE   \t0.994 (0.993-0.994)\t0.978 (0.973-0.984)\n",
      "SINUS   \t0.979 (0.979-0.980)\t0.987 (0.985-0.989)\n",
      "SVT     \t0.986 (0.984-0.988)\t0.983 (0.977-0.989)\n",
      "TRIGEMINY\t0.999 (0.999-1.000)\t0.998 (0.994-1.001)\n",
      "VT      \t0.997 (0.995-0.998)\t0.992 (0.988-0.997)\n",
      "WENCKEBACH\t0.991 (0.990-0.993)\t0.990 (0.985-0.996)\n",
      "Average\t\t0.986\t0.987\n"
     ]
    }
   ],
   "source": [
    "# Position-level and record-level confusion counts -> F1 table, then ROC AUCs.\n",
    "seq_counts = stats(labels, probs)\n",
    "set_counts = set_stats(labels, probs)\n",
    "print_results(seq_counts, set_counts)\n",
    "print_aucs(labels, probs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
